
Conversation

Collaborator

@timmoon10 timmoon10 commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
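
For orientation, here is a rough sketch of how these ops could compose into a grouped MLP block. The te.ops module path, constructor signatures, and the Sequential composition are assumptions for illustration, not the final API:

import transformer_engine.pytorch as te

# Hypothetical composition of the new ops into an MoE grouped MLP block.
# Constructor signatures and the Sequential composition are assumptions.
num_groups, hidden_size, ffn_size = 8, 4096, 16384
mlp = te.ops.Sequential(
    te.ops.GroupedLinear(num_groups, hidden_size, 2 * ffn_size, bias=False),  # FC1
    te.ops.ScaledSwiGLU(glu_interleave_size=32),  # gated activation, post-scaled e.g. by routing probs
    te.ops.GroupedLinear(num_groups, ffn_size, hidden_size, bias=False),      # FC2
)
# With an MXFP8 recipe, the fuser may replace this pattern with the CuTe DSL
# grouped GEMM + SwiGLU kernel when the shape constraints are satisfied.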

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op and support for interleaving SwiGLU gate and linear units (reference sketch below)
  • Add a fused operation for grouped MLP
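
As a reference for the post-scaled, interleaved SwiGLU, an unfused sketch of the intended math follows. The exact interleaving layout and the gate-before-linear ordering are assumptions:

import torch
import torch.nn.functional as F

def scaled_swiglu_reference(x: torch.Tensor, scales: torch.Tensor,
                            glu_interleave_size: int = 32) -> torch.Tensor:
    """Unfused sketch: SwiGLU with gate/linear units interleaved in blocks of
    `glu_interleave_size` channels, scaled after the activation (e.g. by MoE
    routing probabilities). Layout and gate-first ordering are assumptions."""
    *batch, channels = x.shape
    x = x.reshape(*batch, channels // (2 * glu_interleave_size), 2, glu_interleave_size)
    gate, linear = x[..., 0, :], x[..., 1, :]
    out = (F.silu(gate) * linear).reshape(*batch, channels // 2)
    return out * scales  # post-scaling applied to the activation output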

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Test is too permissive since the test should still be failing. The weights are not properly interleaved yet.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 and others added 4 commits February 5, 2026 02:18
@greptile-apps greptile-apps bot left a comment

9 files reviewed, 4 comments

Comment on lines +290 to +292
    quantizer=fc2_input_quantizers[group_idx],
    requires_grad=False,
    with_gemm_swizzled_scales=True,

Incorrect grad-required flags

In ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.fuser_forward, swiglu_ctx.input_requires_grad and swiglu_ctx.extra_input_requires_grad are set to True unconditionally (and input_requires_grad is set to requires_grad unconditionally). This will make ScaledSwiGLU.fuser_backward compute grad_input and grad_extra_input even when neither input_ nor scales require grads, which violates autograd semantics and can raise errors (e.g., when scales.detach() is passed into the fused kernel but extra_input_requires_grad=True forces a gradient).

These flags should be set based on the actual requirements (see the toy sketch after this list):

  • input_requires_grad = input_.requires_grad
  • swiglu_ctx.extra_input_requires_grad = scales.requires_grad
  • and for FC weights, check each parameter’s requires_grad (not just weight0).
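
A toy, standalone illustration of that flag derivation (names are illustrative, not taken from the PR):

import torch

# Derive grad-required flags from the tensors instead of hard-coding True.
x = torch.randn(4, 8, requires_grad=True)
scales = torch.rand(4, 1).detach()  # e.g. routing probabilities with no grad

input_requires_grad = x.requires_grad              # True
extra_input_requires_grad = scales.requires_grad   # False -> skip grad_extra_input

weights = [torch.nn.Parameter(torch.randn(8, 8)) for _ in range(3)]
weights[1].requires_grad_(False)  # e.g. a frozen expert
weight_requires_grad = [w.requires_grad for w in weights]  # per parameter, not just weights[0]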

Comment on lines +420 to +460
# Return immediately if fused kernel is not supported
if not BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
    return ops

# Check if recipe is supported
if recipe is None:
    return ops
if not recipe.mxfp8():
    return ops

# Scan through ops, fusing if possible
out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:

    # Check if window matches pattern
    matches_pattern = True
    if not (
        isinstance(window[0], GroupedLinear)
        and isinstance(window[1], ScaledSwiGLU)
        and isinstance(window[2], GroupedLinear)
    ):
        matches_pattern = False
    elif window[0].has_bias or window[2].has_bias:
        matches_pattern = False
    elif window[0].num_groups != window[2].num_groups:
        matches_pattern = False
    elif (
        window[0].in_features % 256 != 0
        or window[0].out_features % 256 != 0
        or window[2].in_features % 256 != 0
        or window[2].out_features % 256 != 0
    ):
        matches_pattern = False
    elif window[1].glu_interleave_size != 32:
        matches_pattern = False

    if matches_pattern:
        # Construct fused op if window matches pattern
        op = BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(
            fc1=window[0],

Broken fusion window scan

Both fuse_backward_ops and fuse_forward_ops have a window/shift loop that can drop or reorder ops when the pattern doesn’t match. In the non-matching branch you do out.extend(window[:-2]); window = window[-2:] and then immediately do out.extend(window[:-3]) (which is a no-op for a 2-element window) before refilling. This causes the scan to advance by 1 op in some cases and by 2 in others, and it never emits window[-1] until the very end. For sequences like [A,B,C,D] where [A,B,C] doesn’t match but [B,C,D] would (or vice versa), this loop will not correctly consider all 3-op windows and can produce an incorrect fused op list.

This needs a standard sliding-window approach (advance by 1 when not matching; replace 3->1 when matching) to ensure no ops are skipped or duplicated.
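
A generic sketch of such a sliding-window scan (placeholder callables, not the PR's fuse_backward_ops/fuse_forward_ops):

from typing import Callable, List

def fuse_windows(ops: List[object],
                 matches: Callable[[List[object]], bool],
                 make_fused: Callable[[List[object]], object]) -> List[object]:
    """3-op sliding-window fusion: advance by one op when the window does not
    match, replace the three ops with one fused op when it does."""
    out: List[object] = []
    i = 0
    while i < len(ops):
        window = ops[i:i + 3]
        if len(window) == 3 and matches(window):
            out.append(make_fused(window))
            i += 3  # the fused op is final; skip past it
        else:
            out.append(ops[i])
            i += 1  # emit one op and slide the window forward
    return out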

Collaborator Author

@timmoon10 timmoon10 Feb 12, 2026

I think your reading of the code is incorrect, although I would take suggestions for improving clarity. The loop condition guarantees the window size is 3 at the beginning of each iteration. We either fuse the window or eject the first op in the window, and then we refill back up to window size 3.

The only hint of an edge case I can see is if we perform a fusion, and that fused op can participate in further fusions. Then we might want to rewind the sliding window so that we reexamine the fused op in each window position. However, we know that the fused op is final, and we can safely advance the window past it.
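
For comparison, a sketch of the eject-and-refill structure described above, which visits the same three-op windows (placeholder callables, not the PR code):

from typing import Callable, List

def fuse_by_eject_and_refill(ops: List[object],
                             matches: Callable[[List[object]], bool],
                             make_fused: Callable[[List[object]], object]) -> List[object]:
    out: List[object] = []
    window, rest = ops[:3], ops[3:]
    while len(window) == 3:
        if matches(window):
            out.append(make_fused(window))  # fused op is final; do not revisit it
            window = []
        else:
            out.append(window.pop(0))  # eject only the first op in the window
        while len(window) < 3 and rest:  # refill the window back up to three ops
            window.append(rest.pop(0))
    out.extend(window)  # flush the remaining (fewer than three) ops
    return out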

@greptile-apps greptile-apps bot left a comment

4 files reviewed, no comments

Labels

performance (Performance issues)